# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import inspect
from collections import defaultdict
from typing import Any, List, Optional, Tuple

import torch.nn as nn

from nni.common.serializer import is_traceable
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid

from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat, NasBench101Cell, NasBench101Mutator


class LayerChoiceMutator(Mutator):
    def __init__(self, nodes: List[Node]):
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        candidates = self.nodes[0].operation.parameters['candidates']
        chosen = self.choice(candidates)
        for node in self.nodes:
            # Each layer choice corresponds to a cell, which is unconnected in the base graph.
            # We add the connections here in the mutation logic.
            # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
            target = model.graphs[node.operation.cell_name]
            chosen_node = target.get_node_by_name(chosen)
            assert chosen_node is not None
            target.add_edge((target.input_node, 0), (chosen_node, None))
            target.add_edge((chosen_node, None), (target.output_node, None))
            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))

            # remove redundant nodes
            for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
                if rm_node.name != chosen_node.name:
                    rm_node.remove()


class InputChoiceMutator(Mutator):
    def __init__(self, nodes: List[Node]):
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        n_candidates = self.nodes[0].operation.parameters['n_candidates']
        n_chosen = self.nodes[0].operation.parameters['n_chosen']
        candidates = list(range(n_candidates))
        if n_chosen is None:
            chosen = [i for i in candidates if self.choice([False, True])]
            # FIXME This is a hack to make choice align with the previous format
            self._cur_samples = chosen
        else:
            chosen = [self.choice(candidates) for _ in range(n_chosen)]
        for node in self.nodes:
            target = model.get_node_by_name(node.name)
            target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
                                    {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})


class ValueChoiceMutator(Mutator):
    def __init__(self, nodes: List[Node], candidates: List[Any]):
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes
        self.candidates = candidates

    def mutate(self, model):
        chosen = self.choice(self.candidates)
        for node in self.nodes:
            target = model.get_node_by_name(node.name)
            target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})


class ParameterChoiceMutator(Mutator):
    def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
        node, argname = nodes[0]
        super().__init__(label=node.operation.parameters[argname].label)
        self.nodes = nodes
        self.candidates = candidates

    def mutate(self, model):
        chosen = self.choice(self.candidates)
        for node, argname in self.nodes:
            chosen_value = node.operation.parameters[argname].access(chosen)
            target = model.get_node_by_name(node.name)
            target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})


class RepeatMutator(Mutator):
    def __init__(self, nodes: List[Node]):
        # nodes is a subgraph consisting of repeated blocks.
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
        u = graph.input_node
        chain = []
        while u != graph.output_node:
            if u != graph.input_node:
                chain.append(u)
            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successor}.'
            u = u.successors[0]
        return chain

    def mutate(self, model):
        min_depth = self.nodes[0].operation.parameters['min_depth']
        max_depth = self.nodes[0].operation.parameters['max_depth']
        if min_depth < max_depth:
            chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
        for node in self.nodes:
            # the logic here is similar to layer choice. We find cell attached to each node.
            target: Graph = model.graphs[node.operation.cell_name]
            chain = self._retrieve_chain_from_graph(target)
            for edge in chain[chosen_depth - 1].outgoing_edges:
                edge.remove()
            target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
            for rm_node in chain[chosen_depth:]:
                for edge in rm_node.outgoing_edges:
                    edge.remove()
                rm_node.remove()
            # to delete the unused parameters.
            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))


def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
    applied_mutators = []

    ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.InputChoice'))
    for node_list in ic_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
            _is_all_equal(map(lambda node: node.operation.parameters['n_chosen'], node_list)), \
            'Input choice with the same label must have the same number of candidates.'
        mutator = InputChoiceMutator(node_list)
        applied_mutators.append(mutator)

    vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.ValueChoice'))
    for node_list in vc_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['candidates'], node_list)), \
            'Value choice with the same label must have the same candidates.'
        mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
        applied_mutators.append(mutator)

    pc_nodes = []
    for node in model.get_nodes():
        for name, choice in node.operation.parameters.items():
            if isinstance(choice, ValueChoice):
                pc_nodes.append((node, name))
    pc_nodes = _group_parameters_by_label(pc_nodes)
    for node_list in pc_nodes:
        assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
            'Value choice with the same label must have the same candidates.'
        first_node, first_argname = node_list[0]
        mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
        applied_mutators.append(mutator)

    # apply layer choice at last as it will delete some nodes
    lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
                                      model.get_nodes_by_type('_cell')))
    for node_list in lc_nodes:
        assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
            'Layer choice with the same label must have the same number of candidates.'
        mutator = LayerChoiceMutator(node_list)
        applied_mutators.append(mutator)

    repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
                                          model.get_nodes_by_type('_cell')))
    for node_list in repeat_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
            _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
            'Repeat with the same label must have the same number of candidates.'
        mutator = RepeatMutator(node_list)
        applied_mutators.append(mutator)

    if applied_mutators:
        return applied_mutators
    return None


# The following are written for pure-python mode


class ManyChooseManyMutator(Mutator):
    """
    Choose based on labels. Will not affect the model itself.
    """

    def __init__(self, label: Optional[str]):
        super().__init__(label=label)

    @staticmethod
    def candidates(node):
        if 'n_candidates' in node.operation.parameters:
            return list(range(node.operation.parameters['n_candidates']))
        else:
            return node.operation.parameters['candidates']

    @staticmethod
    def number_of_chosen(node):
        if 'n_chosen' in node.operation.parameters:
            return node.operation.parameters['n_chosen']
        return 1

    def mutate(self, model: Model):
        # this mutate does not have any effect, but it is recorded in the mutation history
        for node in model.get_nodes_by_label(self.label):
            n_chosen = self.number_of_chosen(node)
            if n_chosen is None:
                candidates = [i for i in self.candidates(node) if self.choice([False, True])]
                # FIXME This is a hack to make choice align with the previous format
                # For example, it will convert [False, True, True] into [1, 2].
                self._cur_samples = candidates
            else:
                for _ in range(n_chosen):
                    self.choice(self.candidates(node))
            break


def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    model.python_class = pytorch_model.__class__
    if len(inspect.signature(model.python_class.__init__).parameters) > 1:
        if not is_model_wrapped(pytorch_model):
            raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
                             'if your model has init parameters.')
        model.python_init_params = pytorch_model.trace_kwargs
    else:
        model.python_init_params = {}

    for name, module in pytorch_model.named_modules():
        # tricky case: value choice that serves as parameters are stored in traced arguments
        if is_basic_unit(module):
            for key, value in module.trace_kwargs.items():
                if isinstance(value, ValueChoice):
                    node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
                    node.label = value.label

        if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass
        if isinstance(module, LayerChoice):
            node = graph.add_node(name, 'LayerChoice', {'candidates': module.names})
            node.label = module.label
        if isinstance(module, InputChoice):
            node = graph.add_node(name, 'InputChoice',
                                  {'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen})
            node.label = module.label
        if isinstance(module, ValueChoice):
            node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
            node.label = module.label
        if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
            node = graph.add_node(name, 'Repeat', {
                'candidates': list(range(module.min_depth, module.max_depth + 1))
            })
            node.label = module.label
        if isinstance(module, NasBench101Cell):
            node = graph.add_node(name, 'NasBench101Cell', {
                'max_num_edges': module.max_num_edges
            })
            node.label = module.label
        if isinstance(module, Placeholder):
            raise NotImplementedError('Placeholder is not supported in python execution mode.')

    model.status = ModelStatus.Frozen
    if not graph.hidden_nodes:
        return model, None

    mutators = []
    mutators_final = []
    for nodes in _group_by_label_and_type(graph.hidden_nodes):
        assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
            f'Node with label "{nodes[0].label}" does not all have the same type.'
        assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
            f'Node with label "{nodes[0].label}" does not agree on parameters.'
        if nodes[0].operation.type == 'NasBench101Cell':
            mutators_final.append(NasBench101Mutator(nodes[0].label))
        else:
            mutators.append(ManyChooseManyMutator(nodes[0].label))
    return model, mutators + mutators_final


# mutations for evaluator

class EvaluatorValueChoiceMutator(Mutator):
    def __init__(self, keys: List[str], label: Optional[str]):
        self.keys = keys
        super().__init__(label=label)

    def mutate(self, model: Model):
        # make a copy to mutate the evaluator
        model.evaluator = model.evaluator.trace_copy()
        chosen = None
        for i, key in enumerate(self.keys):
            value_choice: ValueChoice = model.evaluator.trace_kwargs[key]
            if i == 0:
                # i == 0 is needed here because there can be candidates of "None"
                chosen = self.choice(value_choice.candidates)
            # get the real chosen value after "access"
            model.evaluator.trace_kwargs[key] = value_choice.access(chosen)
        return model


def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
    # take all the value choice in the kwargs of evaluaator into a list
    if not is_traceable(evaluator):
        return []
    mutator_candidates = {}
    mutator_keys = defaultdict(list)
    for key, param in evaluator.trace_kwargs.items():
        if isinstance(param, ValueChoice):
            # merge duplicate labels
            for mutator in existing_mutators:
                if mutator.name == param.label:
                    raise ValueError(f'Found duplicated labels for mutators {param.label}. When two mutators have the same name, '
                                     'they would share choices. However, sharing choices between model and evaluator is not yet supported.')
            if param.label in mutator_candidates and mutator_candidates[param.label] != param.candidates:
                raise ValueError(f'Duplicate labels for evaluator ValueChoice {param.label}. They should share choices.'
                                 f'But their candidate list is not equal: {mutator_candidates[param.label][1]} vs. {param.candidates}')
            mutator_keys[param.label].append(key)
            mutator_candidates[param.label] = param.candidates
    mutators = []
    for key in mutator_keys:
        mutators.append(EvaluatorValueChoiceMutator(mutator_keys[key], key))
    return mutators


# utility functions


def _is_all_equal(lst):
    last = None
    for x in lst:
        if last is not None and last != x:
            return False
        last = x
    return True


def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
    result = {}
    for node in nodes:
        key = (node.label, node.operation.type)
        if key not in result:
            result[key] = []
        result[key].append(node)
    return list(result.values())


def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
    result = {}
    for node in nodes:
        label = node.operation.parameters['label']
        if label not in result:
            result[label] = []
        result[label].append(node)
    return list(result.values())


def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
    result = {}
    for node, argname in nodes:
        label = node.operation.parameters[argname].label
        if label not in result:
            result[label] = []
        result[label].append((node, argname))
    return list(result.values())
